
from sklearn.preprocessing import label
from nlu_model.cls.model_pytorch.model_train import TrainModelPipeline
from nlu_model.cls.thews.util import data_loader, pretrain_w2v, word2idx

if __name__ == "__main__":
    # Load the train/test split and the label vocabulary from the project's
    # "thews" dataset utilities, then resolve labels to the title dictionary.
    x_train, x_test, y_train, y_test, label_list = data_loader()
    label_list = label_list.TITLE_DICT

    # Pretrain word2vec on the training texts. Returns the embedding matrix,
    # the gensim model, and the token -> index mapping.
    embedding_weights, imdb_w2v, word2idx_dic = pretrain_w2v(x_train)

    # Convert token sequences to fixed-length index sequences
    # (presumably padded/truncated to 20 tokens — confirm in word2idx).
    x_train = word2idx(x_train, word2idx_dic, padding=20)
    x_test = word2idx(x_test, word2idx_dic, padding=20)

    # Hyper-parameters shared by every model configuration below.
    base_param = {
        "vocab_size": len(word2idx_dic),
        "embed_dim": 300,
        "class_num": len(label_list),
        "kernel_num": 16,
        "kernel_size": [1, 2, 3, 4, 5],
        "dropout": 0.5,
        "pre_word_embeds": embedding_weights,
        "learning_rate": 0.001,
    }

    # TEXTCNN--THEWS--ACC--0.6663793103448276
    textCNN_param = {
        **base_param,
        "model_name": "TEXTCNN",
    }

    # TEXTRNN--THEWS--ACC--0.5997938530734632
    textRNN_param = {
        **base_param,
        "model_name": "TEXTRNN",
        "lstm_hidden": 128,
        "lstm_num_layers": 2,
    }

    # TEXTRCNN--THEWS--ACC--0.6103260869565217
    # (original comment mislabeled this run as TEXTRNN)
    textRCNN_param = {
        **base_param,
        "model_name": "TEXTRCNN",
        "lstm_hidden": 128,
        "lstm_num_layers": 2,
        "pad_size": 20,
    }

    transformer_cls_param = {
        **base_param,
        "model_name": "TRANSFORMERCLS",
        "lstm_hidden": 128,
        "lstm_num_layers": 2,
        "pad_size": 20,
    }

    # Select which model to train by swapping MODEL_CONF.
    train_config = {
        "MODEL_CONF": textCNN_param,
        "batch_size": 64,
        "epoch": 3,
    }

    train_model_pipeline = TrainModelPipeline(train_config)
    # Train on the first 10k samples only (keeps the run short), then
    # evaluate on the full training set and the held-out test set.
    train_model_pipeline.call_train(x_train[:10000], y_train[:10000])
    train_model_pipeline.call_evaluate(x_train, y_train)
    train_model_pipeline.call_evaluate(x_test, y_test)